{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Copyright 2024 Google LLC.\n",
        "SPDX-License-Identifier: Apache-2.0\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ],
      "metadata": {
        "id": "mTHqz8SOxhLw"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Demo:** Language Model Predictive Control\n",
        "\n",
        "**What is LMPC?**\n",
        "\n",
        "Language Model Predictive Control (LMPC) is a method to improve teachability (i.e. fast adaptation to language feedback) of robot code-writing LLMs. One key observation is that when human-robot interactions (HRI) are formulated as a partially observable Markov decision process (POMDP, in which human language inputs are observations, and robot code outputs are actions), then\n",
        "training an LLM to autoregressively complete previous interactions\n",
        "can be viewed as training a transition dynamics model -- that can be\n",
        "combined with classic robotics techniques such as model predictive\n",
        "control (MPC) to discover shorter paths to preferred outcomes (also predicted by the model). Specifically, LMPC fine-tunes an LLM to predict imagined future rollouts of language-based human-robot interactions -- then at inference time, samples multiple futures (with non-zero decoding temperature) to search for the best one and take the next action (i.e., receding horizon control as a decoding strategy).\n",
        "\n",
        "This is an open-source implementation of the work: [\"Learning to Learn Faster from Human Feedback with Language Model Predictive Control\"](https://robot-teaching.github.io/)\n",
        "\n",
        "**What's in this notebook?**\n",
        "\n",
        "Fine-tune LLMs with LMPC to improve teachability for simple 2D goal-driven navigation with simulated language feedback.\n",
        "\n",
        "**Important:** this notebook was written for **illustrative purposes** only (to show one way how LMPC can be implemented). This toy environment is far from perfect in terms of eliciting the strengths of LMPC, and over-simplifies the HRI setting. In fact, LMPC can be quite sensitive to small tweaks to the environment, and the performance improvements are relatively marginal. Creating meaningful toy environments for HRI is still an open problem, that we leave for future extensions. Note that the code in this notebook was also written for hackability rather than speed.\n",
        "\n",
        "### **Quick Start:**\n",
        "\n",
        "**Step 1.** Register for an [OpenAI API key](https://openai.com/blog/openai-api/) to use GPT-3.5 and enter it below\n",
        "\n",
        "**Step 2.** Menu > Runtime > Run all (takes about 40 mins to complete everything)"
      ],
      "metadata": {
        "id": "TDnm4lG7inYb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "openai_api_key = \"\""
      ],
      "metadata": {
        "id": "_rfwepbxgHn6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Setup"
      ],
      "metadata": {
        "id": "9EJIpYdbDQZk"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x4HnsvEmR8mL"
      },
      "outputs": [],
      "source": [
        "!pip install openai==1.12.0\n",
        "\n",
        "import time\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from openai import OpenAI\n",
        "\n",
        "client = OpenAI(api_key=openai_api_key)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Helper Functions"
      ],
      "metadata": {
        "id": "JWnae8XgDUUi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "BASE_MODEL = \"gpt-3.5-turbo-1106\"  # This is also the model we are fine-tuning.\n",
        "\n",
        "def LLM(messages, model=BASE_MODEL, stop=None, max_tokens=256, temperature=0.3):\n",
        "  responses = client.chat.completions.create(model=model, messages=messages, max_tokens=max_tokens, temperature=temperature, stop=stop)\n",
        "  text = responses.choices[0].message.content\n",
        "  return text\n",
        "\n",
        "# Test LLM.\n",
        "LLM([{\"role\": \"user\", \"content\": \"hello world!\"}])"
      ],
      "metadata": {
        "id": "H5LvgeMlywdW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Environment**\n",
        "\n",
        "This toy 2D goal-driven navigation environment starts with a language instruction (e.g. \"navigate to the bottom left\"), that the agent (LLM) takes as input, to then output simple navigation code (as a sequence of primitive functions such as `move_left()`). If the code does not enable the agent to reach the goal, then the environment subsequently simulates language feedback at every timestep. The objective of LMPC is to enable the agent to minimize the average number of language inputs (e.g. corrections) before successfully reaching the goal.\n",
        "\n",
        "Human guidance is imperfect in real HRI settings, and this is modeled in the toy environment as noise on the feedback simulator. This notebook shows how to improve LMPC to respond to navigation feedback with top-user conditioning, which (i) identifies top users (by performance on training tasks), (ii) groups their data together with a special username “top-user,” then (iii) conditions inference-time LMPC rollouts on this special username (i.e., assume everyone is a top-user). In the toy setting, (as of Feb 2024) fine-tuning `gpt-3.5-turbo-1106` reduces the average number of language inputs before success from 2.4 to 2.2, and improves the success rate of held out test tasks."
      ],
      "metadata": {
        "id": "u1eDp5f0CkgS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "MAX_NUM_FEEDBACK = 5  # Number of human language inputs (not including initial instruction).\n",
        "TOP_USERS = [\"user-noise-0.0\"]\n",
        "TEST_TASKS = [\"top left\", \"left\"]"
      ],
      "metadata": {
        "id": "VPw6pdGC2dC-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class SimpleParticleNav:\n",
        "\n",
        "  def __init__(self, is_real_user=False):\n",
        "    self.goals = {'center':   (0, 0),\n",
        "                  'bottom left': (-0.5, -0.5),\n",
        "                  'bottom': (0, -0.5),\n",
        "                  'bottom right': (0.5, -0.5),\n",
        "                  'right': (0.5, 0),\n",
        "                  'top right': (0.5, 0.5),\n",
        "                  'top':      (0, 0.5),\n",
        "                  # Test tasks (compositional generalization):\n",
        "                  'top left': (-0.5, 0.5),\n",
        "                  'left':     (-0.5, 0),}\n",
        "    self.is_real_user = is_real_user\n",
        "\n",
        "  def reset(self, noise=0):\n",
        "    self.noise = noise  # Noise on the feedback.\n",
        "    self.terminated = False\n",
        "    self.num_steps = 0\n",
        "\n",
        "    # Build a valid random initialization that is not immediately success.\n",
        "    while True:\n",
        "      self.init_name = np.random.choice(list(self.goals.keys()))\n",
        "      self.agent_pos = np.float32(self.goals[self.init_name])\n",
        "      self.goal_name = np.random.choice(list(self.goals.keys()))\n",
        "      self.goal_pos = np.float32(self.goals[self.goal_name])\n",
        "      if self.is_success() is None:  # None means episode is still ongoing.\n",
        "        break\n",
        "\n",
        "    self.path = self.agent_pos.copy()\n",
        "    state = f\"new episode: the agent is at the {self.init_name}\"\n",
        "    instruction = f\"user: navigate to the {self.goal_name} goal\"\n",
        "    return state, instruction, self.is_success()\n",
        "\n",
        "  def step(self, act):\n",
        "    self.agent_pos += act\n",
        "    self.agent_pos = np.clip(self.agent_pos, -1, 1)  # Note: nonlinearity.\n",
        "    self.path = np.vstack((self.path, self.agent_pos.copy()))\n",
        "    self.num_steps += 1\n",
        "    return self.get_feedback(), self.is_success()\n",
        "\n",
        "  def get_feedback(self):\n",
        "    if self.is_real_user:\n",
        "      self.render()\n",
        "      return f\"user: {input(f'Enter next instruction:')}\"\n",
        "\n",
        "    # Random noise on where the user thinks the goal is.\n",
        "    noise_goal_pos = self.goal_pos if np.random.rand() > self.noise else self.goals[np.random.choice(list(self.goals.keys()))]\n",
        "    if noise_goal_pos[1] > self.agent_pos[1]:\n",
        "      return \"user: go north\"  # Slightly less trivial mapping of feedback to code.\n",
        "    elif noise_goal_pos[1] < self.agent_pos[1]:\n",
        "      return \"user: go south\"\n",
        "    elif noise_goal_pos[0] > self.agent_pos[0]:\n",
        "      return \"user: go east\"\n",
        "    elif noise_goal_pos[0] < self.agent_pos[0]:\n",
        "      return \"user: go west\"\n",
        "    else:\n",
        "      return \"user: do not move\"\n",
        "\n",
        "  def is_success(self):\n",
        "    if np.all(np.isclose(self.goal_pos, self.agent_pos)):\n",
        "      return True\n",
        "    if self.num_steps > MAX_NUM_FEEDBACK:\n",
        "      return False\n",
        "\n",
        "  def render(self):\n",
        "    plt.scatter(self.goal_pos[0], self.goal_pos[1], c=\"tab:green\", s=300, alpha=0.5)\n",
        "    plt.plot([-1, -1, 1, 1, -1],\n",
        "             [-1, 1, 1, -1, -1], c=\"#dddddd\", linewidth=2)\n",
        "    plt.plot(self.path[:, 0], self.path[:, 1], c=\"tab:blue\", alpha=0.5, linewidth=2, linestyle='dashed')\n",
        "    plt.scatter(self.path[-1, 0], self.path[-1, 1], c=\"tab:blue\", s=50, alpha=0.5, zorder=10)\n",
        "    plt.axis('equal')\n",
        "    plt.axis('off')\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "class ParticleAgent:\n",
        "  \"\"\"Simple helper class that holds the next action from exec().\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    self.delta = np.float32([0, 0])\n",
        "\n",
        "  def move_up(self):\n",
        "    self.delta[1] += 0.5\n",
        "\n",
        "  def move_down(self):\n",
        "    self.delta[1] -= 0.5\n",
        "\n",
        "  def move_left(self):\n",
        "    self.delta[0] -= 0.5\n",
        "\n",
        "  def move_right(self):\n",
        "    self.delta[0] += 0.5\n",
        "\n",
        "  def wait(self):\n",
        "    pass\n",
        "\n",
        "\n",
        "def code_to_actions(code):\n",
        "  agent = ParticleAgent()\n",
        "  try:\n",
        "    exec(code)\n",
        "  except:\n",
        "    print(\"Invalid code.\")\n",
        "    pass\n",
        "  return agent.delta"
      ],
      "metadata": {
        "id": "Yql-0CMpnEco"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Prompt**\n",
        "\n",
        "The policy is driven by a prompt passed as input to an instruction-tuned code-writing LLM. The prompt contains a preamble that describes the environments,  API functions available to the agent, and an example of interacting with a user.\n",
        "\n",
        "Note this assumes that the LLM (with a prompt) already has some base level performance on code-writing to achieve non-zero success during data collection. Noise is only added to the user feedback from the environment -- future extensions of this notebook may consider adding noise on the agent's actions as well to simulate imperfect human-robot interactions."
      ],
      "metadata": {
        "id": "S2v9R5hLB5fw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "PROMPT = [{\"role\": \"system\", \"content\": \"You are an agent that navigates to a goal location, and can call the following functions: move_up(), move_down(), move_left(), move_right(), wait(). Please write code according to user feedback to navigate to the goal. For example:\"},\n",
        "          {\"role\": \"user\", \"content\": \"new episode: the agent is at the top right\"},\n",
        "          {\"role\": \"assistant\", \"content\": \"user: navigate to the bottom\"},\n",
        "          {\"role\": \"assistant\", \"content\": \"agent.move_down()\\nagent.move_left()\"},\n",
        "          {\"role\": \"assistant\", \"content\": \"user: go south\"},\n",
        "          {\"role\": \"assistant\", \"content\": \"agent.move_down()\"}]"
      ],
      "metadata": {
        "id": "yN5cyuUM55Xd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Data Collection (2 mins)**\n",
        "\n",
        "Collect 25 episodes (chat sessions) with each user of varying noise.\n",
        "\n",
        "Note we are prompting the chat completion APIs with every message as coming from \"assistant.\" As of Feb 2024, this is needed so we can fine-tune the model to predict \"what a user might say.\" Doing so otherwise does not work (it is possible that OpenAI model training incurs different completion losses on tokens from \"assistant\" vs \"user\")."
      ],
      "metadata": {
        "id": "e1n_o6W8EHIA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def collect_data(policy, name=\"user\", is_real_user=False):\n",
        "  data = {\"user\": [], \"task\": [], \"session\": [], \"chat_length\": [], \"success\":[]}\n",
        "\n",
        "  env = SimpleParticleNav(is_real_user=is_real_user)\n",
        "  for noise in [0, 0.3, 0.6, 0.8]:\n",
        "    user = f\"user-noise-{noise:.1f}\"\n",
        "    for _ in range(25):\n",
        "\n",
        "      # New episode (chat session).\n",
        "      episode = []  # Tracks messages in the current episode.\n",
        "      state, feedback, success = env.reset(noise=noise)\n",
        "      print(f\"\\nUser: {user}:\\nDescription: {state}\")\n",
        "      episode.append({\"role\": \"user\", \"content\": state})\n",
        "\n",
        "      # Dialogue between agent and user.\n",
        "      while success is None:\n",
        "        feedback = feedback.replace(\"user:\", f\"{name}:\")\n",
        "        print(f\"  {feedback}\")\n",
        "\n",
        "        # Policy.\n",
        "        episode.append({\"role\": \"assistant\", \"content\": feedback})\n",
        "        code = policy(episode)\n",
        "        print(f\"  code:\\t{code}\".replace(\"\\n\", \"\\n\\t\"))\n",
        "        episode.append({\"role\": \"assistant\", \"content\": code})\n",
        "        act = code_to_actions(code)\n",
        "\n",
        "        # Step environment.\n",
        "        feedback, success = env.step(act)\n",
        "\n",
        "      data[\"user\"].append(user)\n",
        "      data[\"task\"].append(env.goal_name)\n",
        "      data[\"session\"].append(episode)\n",
        "      data[\"chat_length\"].append(env.num_steps)\n",
        "      data[\"success\"].append(success)\n",
        "\n",
        "      print(\"Success:\", success)\n",
        "      env.render()\n",
        "\n",
        "  return data\n",
        "\n",
        "\n",
        "np.random.seed(42)\n",
        "train_data = collect_data(policy=lambda x: LLM(PROMPT + x), name=\"user\")"
      ],
      "metadata": {
        "id": "Iqy_jEGsdT5x"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "###Show Metrics\n",
        "\n",
        "Show some statistics from the training data: avg success rates and chat lengths."
      ],
      "metadata": {
        "id": "3ejpVu0sFWAY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def task_to_split(task):\n",
        "  return 'test' if task in TEST_TASKS else 'train'\n",
        "\n",
        "def show_metrics(data):\n",
        "  # Show overall success and chat length metrics.\n",
        "  print(\"Overall:\", f\"\\n\\t\\t\\tSuccess: {np.mean(data['success'])*100:.0f}%\", f\"\\tAvg Chat Length: {np.mean([l for l, s in zip(data['chat_length'], data['success']) if s]):.1f}\")\n",
        "\n",
        "  # Show success and chat length metrics by user.\n",
        "  user_to_success = {user: [] for user in set(data[\"user\"])}\n",
        "  user_to_chat_length = {user: [] for user in set(data[\"user\"])}\n",
        "  for user, chat_length, success in zip(data[\"user\"], data[\"chat_length\"], data[\"success\"]):\n",
        "    user_to_success[user].append(success)\n",
        "    if success:\n",
        "      user_to_chat_length[user].append(chat_length)\n",
        "  print(\"By users:\")\n",
        "  for user in sorted(list(user_to_success.keys())):\n",
        "    print(\"  \", user, f\"\\tSuccess: {np.mean(user_to_success[user])*100:.0f}%\", f\"\\tAvg Chat Length: {np.mean(user_to_chat_length[user]):.1f}\")\n",
        "\n",
        "  # Show success and chat length metrics by train:test task split.\n",
        "  split_to_success = {\"train\": [], \"test\": []}\n",
        "  split_to_chat_length = {\"train\": [], \"test\": []}\n",
        "  for task, chat_length, success in zip(data[\"task\"], data[\"chat_length\"], data[\"success\"]):\n",
        "    if task in TEST_TASKS:\n",
        "      split_to_success[\"test\"].append(success)\n",
        "      if success:\n",
        "        split_to_chat_length[\"test\"].append(chat_length)\n",
        "    else:\n",
        "      split_to_success[\"train\"].append(success)\n",
        "      if success:\n",
        "        split_to_chat_length[\"train\"].append(chat_length)\n",
        "  print(\"By tasks split:\")\n",
        "  for split in [\"train\", \"test\"]:\n",
        "    print(\"  \", split, f\"\\t\\tSuccess: {np.mean(split_to_success[split])*100:.0f}%\", f\"\\tAvg Chat Length: {np.mean(split_to_chat_length[split]):.1f}\")\n",
        "\n",
        "  # Show success and chat length split by both (user and test/train):\n",
        "  user_to_splits = {user: {} for user in set(data[\"user\"])}\n",
        "  for user in user_to_splits.keys():\n",
        "    user_to_splits[user] = {\"train\": {\"chat_length\": [], \"success\": []}, \"test\": {\"chat_length\": [], \"success\": []}}\n",
        "  for user, task, chat_length, success in zip(data[\"user\"], data[\"task\"], data[\"chat_length\"], data[\"success\"]):\n",
        "    split = task_to_split(task)\n",
        "    user_to_splits[user][split]['success'].append(success)\n",
        "    if success:\n",
        "      user_to_splits[user][split]['chat_length'].append(chat_length)\n",
        "  print (\"By users and splits:\")\n",
        "  for user in sorted(list(user_to_splits.keys())):\n",
        "    for split in [\"train\", \"test\"]:\n",
        "      print(\"  \", f\"{user} ({split})\", f\"\\t\\tSuccess: {np.mean(user_to_splits[user][split]['success'])*100:.0f}%\", f\"\\tAvg Chat Length: {np.mean(user_to_splits[user][split]['chat_length']):.1f}\")\n",
        "    print()\n",
        "\n",
        "show_metrics(train_data)"
      ],
      "metadata": {
        "id": "DOP4uFqqfyRE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Train LMPC (15 mins)**\n",
        "\n",
        "Fine-tunes an LLM (GPT-3.5) for LMPC using OpenAI API."
      ],
      "metadata": {
        "id": "7Ysb8uBAUBX0"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Prepare Data\n",
        "\n",
        "Formats episodes (chat sessions) as .jsonl files that the OpenAI API expects for fine-tuning.\n",
        "\n",
        "Note here that we are also fine-tuning the LLM to be user conditioned, and re-labeling top performing users (e.g. with noise) as \"experts.\" We will later use this during inference time to drive performance improvements."
      ],
      "metadata": {
        "id": "da-PUVw1e5VO"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "import copy\n",
        "\n",
        "train_sessions = []\n",
        "test_sessions = []\n",
        "for user, task, messages, chat_length, success in zip(train_data[\"user\"], train_data[\"task\"], train_data[\"session\"], train_data[\"chat_length\"], train_data[\"success\"]):\n",
        "  name = \"expert\" if user in TOP_USERS else user  # Re-label top users as \"experts.\"\n",
        "  session = copy.deepcopy(PROMPT)\n",
        "  for m in messages:\n",
        "    m[\"content\"] = m[\"content\"].replace(\"user:\", f\"{name}:\")\n",
        "    session.append(m)\n",
        "  session.append({\"role\": \"assistant\", \"content\": f\"success: {success}\"})\n",
        "  test_sessions.append(session) if task in TEST_TASKS else train_sessions.append(session)\n",
        "\n",
        "# Save to .jsonl files.\n",
        "with open('train-lmpc.jsonl', 'w') as f:\n",
        "  for i, session in enumerate(train_sessions):\n",
        "    print(session)\n",
        "    json.dump({\"messages\": session}, f)\n",
        "    if i < len(train_sessions) - 1:\n",
        "      f.write('\\n')\n",
        "  print(\"Train dataset size:\", len(train_sessions))\n",
        "\n",
        "with open('test-lmpc.jsonl', 'w') as f:\n",
        "  for i, session in enumerate(test_sessions):\n",
        "    print(session)\n",
        "    json.dump({\"messages\": session}, f)\n",
        "    if i < len(test_sessions) - 1:\n",
        "      f.write('\\n')\n",
        "  print(\"Test dataset size:\", len(test_sessions))"
      ],
      "metadata": {
        "id": "1F0V4lqjGASE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Fine-tune LLM\n",
        "\n",
        "Uploads training data to OpenAI then starts a fine-tuning job."
      ],
      "metadata": {
        "id": "gCe4tqbLHbZL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Upload training data.\n",
        "train_file = client.files.create(file=open(\"train-lmpc.jsonl\", \"rb\"), purpose=\"fine-tune\")\n",
        "test_file = client.files.create(file=open(\"test-lmpc.jsonl\", \"rb\"), purpose=\"fine-tune\")\n",
        "\n",
        "# Start fine-tuning job.\n",
        "ftjob = client.fine_tuning.jobs.create(training_file=train_file.id, validation_file=test_file.id, model=\"gpt-3.5-turbo-1106\")\n",
        "\n",
        "# Track fine-tuning job until complete.\n",
        "while ftjob.status != \"succeeded\":\n",
        "  ftjob = client.fine_tuning.jobs.retrieve(ftjob.id)\n",
        "  print(time.ctime(), \"Status:\", ftjob.status)\n",
        "  time.sleep(30)\n",
        "fine_tuned_model = ftjob.fine_tuned_model"
      ],
      "metadata": {
        "id": "SngSHR4GEKDZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Test the fine-tuned model."
      ],
      "metadata": {
        "id": "VSPjexLeQo-E"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "LLM([{\"role\": \"user\", \"content\": \"hello world!\"}], fine_tuned_model)"
      ],
      "metadata": {
        "id": "xqYyanNT6w2R"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **LMPC**\n",
        "\n",
        "Simple illustrative implementation of LMPC.\n",
        "\n",
        "**Important:** Our HRI experiments show that LMPC improves LLM adaptation to language feedback for robot code-writing, but that is with real humans. It is worthwhile to think about how LMPC improves performance in a toy setting like this one. There are at least 2 key aspects:\n",
        "\n",
        "* Top user conditioning allows LMPC to generate (and search among) future rollouts with less feedback noise at inference-time.\n",
        "\n",
        "* LMPC is also a decoding strategy that benefits from sampling.\n",
        "  * One can also run top user conditioning without MPC (though we observe slightly worse performance)."
      ],
      "metadata": {
        "id": "Ix5_mOnKjnTn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def LMPC(episode):\n",
        "\n",
        "  # Do LMPC rollouts.\n",
        "  num_rollouts = 4\n",
        "  best_rollout = None\n",
        "  rollouts = []\n",
        "  for _ in range(num_rollouts):\n",
        "    rollout = PROMPT + episode\n",
        "    num_turns = MAX_NUM_FEEDBACK + 1  # Include initial language instructions.\n",
        "    num_msgs = 2 * num_turns  # Each chat turn has 2 messages.\n",
        "    max_preds = num_msgs - len(episode) + 1  # Number of steps into the future.\n",
        "    for _ in range(max_preds):\n",
        "      text = LLM(rollout, fine_tuned_model)\n",
        "      rollout.append({\"role\": \"assistant\", \"content\": text})\n",
        "      if \"success: True\" in text:\n",
        "        rollouts.append(rollout[len(PROMPT)+len(episode):])  # Only look into the future.\n",
        "        break\n",
        "    best_rollout = rollout[len(PROMPT)+len(episode):]  # Defaults to the last rollout.\n",
        "\n",
        "  # Find shortest path to success and take next action.\n",
        "  if len(rollouts) > 0:\n",
        "    for rollout in rollouts:\n",
        "      if len(rollout) < len(best_rollout):\n",
        "        best_rollout = rollout\n",
        "  code = best_rollout[0][\"content\"]\n",
        "  return code"
      ],
      "metadata": {
        "id": "gnQJ00CzQ07s"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Evals (20 mins)**\n",
        "\n",
        "Evaluates fine-tuned LMPC using the same data collection protocol.\n",
        "\n",
        "Top user conditioned LMPC inference assumes each user is a top user."
      ],
      "metadata": {
        "id": "hIxGqbb-I7N6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "np.random.seed(1234)\n",
        "eval_data = collect_data(policy=LMPC, name=\"expert\")  # Assume user is an \"expert\" (top user)."
      ],
      "metadata": {
        "id": "Dgc6ijbaHDU4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(\"BEFORE:\")\n",
        "show_metrics(train_data)\n",
        "print(\"\\nAFTER:\")\n",
        "show_metrics(eval_data)"
      ],
      "metadata": {
        "id": "MbI3ORecd9fj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Playground**\n",
        "\n",
        "Play with the fine-tuned LLM or base model by substituting into the environment as a real human that provides feedback."
      ],
      "metadata": {
        "id": "1lECACM5xCrW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "finetuned_model = True # @param {type:\"boolean\"}\n",
        "\n",
        "if finetuned_model:\n",
        "  tmp_data = collect_data(policy=LMPC, name=\"user\", is_real_user=True)\n",
        "else:\n",
        "  tmp_data = collect_data(policy=lambda x: LLM(PROMPT + x), name=\"user\", is_real_user=True)"
      ],
      "metadata": {
        "id": "lMqXjSUfVatq"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}